#!/bin/bash
#
# Generate sentences with eval_sae_aware_generation.py for one Gemma model,
# using SAE and dual-map MLP checkpoints stored under $HOME.

# `conda activate` only works in a script once the shell hook is loaded.
# Fail loudly if conda is missing instead of eval'ing an empty hook and then
# running the generator with whatever python happens to be on PATH.
if ! command -v conda >/dev/null 2>&1; then
  printf 'error: conda not found on PATH\n' >&2
  exit 1
fi
eval "$(conda shell.bash hook)"
if ! conda activate jax; then
  printf 'error: could not activate conda environment "jax"\n' >&2
  exit 1
fi

# Strict mode is enabled only after activation: the conda hook/activate code
# may legitimately return non-zero internally, which would trip `set -e`.
set -euo pipefail

# Let JAX/XLA allocate GPU memory on demand instead of preallocating a large
# fraction of the device up front (see JAX "GPU memory allocation" docs).
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_PYTHON_CLIENT_ALLOCATOR=platform

# Changeable parameters
export GEMMA_MODEL_NAME_SHORT="gemma-2-2b"
# Passed to the generator as --method; the accepted values are defined by
# eval_sae_aware_generation.py (not visible here).
export TYPE="both"
export TEMPERATURE=1.0
export MAX_NEW_TOKENS=100

# Checkpoint and output locations, all derived from the short model name.
export SAE_MODEL_PATH="$HOME/${GEMMA_MODEL_NAME_SHORT}-sae/k5_final_sae_model.pkl"
export SAE_CODE_PATH="$HOME/${GEMMA_MODEL_NAME_SHORT}-sae/k5_whole_sae_final_z.npy"
export MLP_MODEL_PATH="$HOME/dual-map/model/${GEMMA_MODEL_NAME_SHORT}/dual_map_${GEMMA_MODEL_NAME_SHORT}.pt"
export OUTPUT_DIR="$HOME/data/generate_sentences/${GEMMA_MODEL_NAME_SHORT}"
# Quoted so a $HOME containing spaces/glob characters is still one path; stop
# here rather than launching a long generation run with nowhere to write.
mkdir -p -- "$OUTPUT_DIR" || {
  printf 'error: cannot create output directory %s\n' "$OUTPUT_DIR" >&2
  exit 1
}
# Full model identifier passed to the generator as --model_name.
export GEMMA_MODEL_NAME="google/${GEMMA_MODEL_NAME_SHORT}"

# Generate samples. Every expansion is quoted so paths built from $HOME reach
# the Python script as single arguments even if they contain spaces.
python "$HOME/src/eval/sae-softmax/eval_sae_aware_generation.py" \
    --method "$TYPE" \
    --dataset_name "$HOME/data/original_sentence.jsonl" \
    --output_dir "$OUTPUT_DIR" \
    --model_name "$GEMMA_MODEL_NAME" \
    --sae_model_path "$SAE_MODEL_PATH" \
    --sae_code_path "$SAE_CODE_PATH" \
    --mlp_model_path "$MLP_MODEL_PATH" \
    --temperature "$TEMPERATURE" \
    --max_new_tokens "$MAX_NEW_TOKENS"
